{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.320302Z",
     "start_time": "2025-09-02T02:44:06.309698Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "\n",
    "# Environment wrapper: adapts the gymnasium API for this tutorial.\n",
    "# reset() returns only the state, step() merges terminated/truncated\n",
    "# into a single done flag and hard-caps episodes at 200 steps.\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        # gym.Wrapper.__init__ already stores env as self.env, so no\n",
    "        # extra self.env assignment is needed here.\n",
    "        super().__init__(env)\n",
    "        # Step counter used to truncate episodes at 200 steps.\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        # Drop the info dict and return only the initial state.\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        # Collapse gymnasium's two end-of-episode flags into one.\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        # Hard cap: force the episode to end after 200 steps.\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.01929177,  0.01057963, -0.03562201, -0.01026118], dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 13
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.384254Z",
     "start_time": "2025-09-02T02:44:06.338061Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJh5JREFUeJzt3Xt0VNXdxvHf5EpCSNIEkhBJEAW5ByxgSL2XSASkUuN6vVCMlgVLGlhCFDEWQbDLUOyqt2L4oyq2BVGsQEFBYxAoEgEjVIiCQlGg5ILwJiEouc15196+M81IgFw5ezLfz1rHkzNnZ2bPNmSe7Ms5DsuyLAEAADCIn90VAAAA+DECCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwjq0BZcmSJXL55ZdLp06dJDk5WXbu3GlndQAAgK8HlDfeeEOysrJk/vz58umnn8qQIUMkLS1NysrK7KoSAAAwhMOumwWqHpMRI0bIn/70J33sdDolISFBZsyYIY899pgdVQIAAIYIsONFa2pqpLCwULKzs92P+fn5SWpqqhQUFJxTvrq6Wm8uKsycOnVKoqOjxeFwXLJ6AwCAllN9IqdPn5b4+Hj9uW9cQPn222+lvr5eYmNjPR5Xx/v37z+nfE5OjixYsOAS1hAAALSXo0ePSo8ePcwLKM2lelrUfBWXiooKSUxM1G8wPDzc1roBAICmqays1NM5unTpctGytgSUrl27ir+/v5SWlno8ro7j4uLOKR8cHKy3H1PhhIACAIB3acr0DFtW8QQFBcmwYcMkPz/fY16JOk5JSbGjSgAAwCC2DfGoIZuMjAwZPny4XHPNNfLcc8/JmTNn5IEHHrCrSgAAwNcDyl133SUnTpyQefPmSUlJiQwdOlQ2btx4zsRZAADge2y7DkprJ9lEREToybLMQQEAoON9fnMvHgAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAACAjh9QnnzySXE4HB5bv3793OfPnj0rmZmZEh0dLWFhYZKeni6lpaVtXQ0AAODF2qUHZeDAgVJcXOzetm3b5j43a9YsWbdunaxatUq2bNkix48flzvuuKM9qgEAALxUQLs8aUCAxMXFnfN4RUWFvPzyy7JixQr5+c9/rh979dVXpX///vLxxx/LyJEj26M6AADAy7RLD8pXX30l8fHxcsUVV8jEiRPlyJEj+vHCwkKpra2V1NRUd1k1/JOYmCgFBQXnfb7q6mqprKz02AAAQMfV5gElOTlZli1bJhs3bpTc3Fw5fPiwXH/99XL69GkpKSmRoKAgiYyM9Pie2NhYfe58cnJyJCIiwr0lJCS0dbUBAEBHHuIZM2aM++ukpCQdWHr27ClvvvmmhISEtOg5s7OzJSsry32selAIKQAAdFztvsxY9ZZcddVVcvDgQT0vpaamRsrLyz3KqFU8jc1ZcQkODpbw8HCPDQAAdFztHlCqqqrk0KFD0r17dxk2bJgEBgZKfn6++/yBAwf0HJWUlJT2rgoAAPDVIZ5HHnlExo8fr4d11BLi+fPni7+/v9xzzz16/sjkyZP1cE1UVJTuCZkxY4YOJ6zgAQAA7RZQjh07psPIyZMnpVu3bnLdddfpJcTqa+XZZ58VPz8/fYE2tTonLS1NXnrppbauBgAA8GIOy7Is8TJqkqzqjVHXVWE+CgAAHe/zm3vxAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADj
EFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAAC8P6Bs3bpVxo8fL/Hx8eJwOGTNmjUe5y3Lknnz5kn37t0lJCREUlNT5auvvvIoc+rUKZk4caKEh4dLZGSkTJ48Waqqqlr/bgAAgG8GlDNnzsiQIUNkyZIljZ5fvHixvPDCC7J06VLZsWOHdO7cWdLS0uTs2bPuMiqcFBUVSV5enqxfv16HnqlTp7bunQAAgA7DYakuj5Z+s8Mhq1evlgkTJuhj9VSqZ+Xhhx+WRx55RD9WUVEhsbGxsmzZMrn77rvliy++kAEDBsiuXbtk+PDhuszGjRtl7NixcuzYMf39F1NZWSkRERH6uVUvDAAAMF9zPr/bdA7K4cOHpaSkRA/ruKiKJCcnS0FBgT5WezWs4woniirv5+ene1waU11drd9Uww0AAHRcbRpQVDhRVI9JQ+rYdU7tY2JiPM4HBARIVFSUu8yP5eTk6KDj2hISEtqy2gAAwDBesYonOztbdwe5tqNHj9pdJQAA4C0BJS4uTu9LS0s9HlfHrnNqX1ZW5nG+rq5Or+xxlfmx4OBgPVbVcAMAAB1XmwaUXr166ZCRn5/vfkzNF1FzS1JSUvSx2peXl0thYaG7zKZNm8TpdOq5KgAAAAHN/QZ1vZKDBw96TIzds2ePnkOSmJgoM2fOlN/97nfSp08fHVieeOIJvTLHtdKnf//+cuutt8qUKVP0UuTa2lqZPn26XuHTlBU8AACg42t2QPnkk0/k5ptvdh9nZWXpfUZGhl5K/Oijj+prpajrmqiekuuuu04vI+7UqZP7e5YvX65DyahRo/TqnfT0dH3tFAAAgFZfB8UuXAcFAADvY9t1UAAAANoCAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgPcHlK1bt8r48eMlPj5eHA6HrFmzxuP8/fffrx9vuN16660eZU6dOiUTJ06U8PBwiYyMlMmTJ0tVVVXr3w0AAPDNgHLmzBkZMmSILFmy5LxlVCApLi52b6+//rrHeRVOioqKJC8vT9avX69Dz9SpU1v2DgAAQIcT0NxvGDNmjN4uJDg4WOLi4ho998UXX8jGjRtl165dMnz4cP3Yiy++KGPHjpU//OEPumcGAAD4tnaZg7J582aJiYmRvn37yrRp0+TkyZPucwUFBXpYxxVOlNTUVPHz85MdO3Y0+nzV1dVSWVnpsQEAgI6rzQOKGt75y1/+Ivn5+fL73/9etmzZontc6uvr9fmSkhIdXhoKCAiQqKgofa4xOTk5EhER4d4SEhLautoAAMCbh3gu5u6773Z/PXjwYElKSpIrr7xS96qMGjWqRc+ZnZ0tWVlZ7mPVg0JIAQCg42r3ZcZXXHGFdO3aVQ4ePKiP1dyUsrIyjzJ1dXV6Zc/55q2oOS1qxU/DDQAAdFztHlCOHTum56B0795dH6ekpEh5ebkUFha6y2zatEmcTqckJye3d3UAAEBHHOJR1ytx9YYohw8flj179ug5JGpbsGCBpKen696QQ4cOyaOPPiq9e/eWtLQ0Xb5///56nsqUKVNk6dKlUltbK9OnT9dDQ6zgAQAAisOyLKs5TaHmktx8883nPJ6RkSG5ubkyYcIE2b17t+4lUYFj9OjR8tRTT0lsbKy7rBrOUaFk3bp1evWOCjQvvPCChIWFNakOag6Kmixb
UVHBcA8AAF6iOZ/fzQ4oJiCgAADgfZrz+c29eAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADA+28WCABt5Zt/rpDqqpMXLHPZiNulc9fES1YnAGYgoACwzeniL+X7/z1+wTIxA24UKzpBHA7HJasXAPsxxAPAaM76OhHxunuaAmglAgoAo1kqoJBPAJ9DQAFgNGd9LT0ogA8ioAAwvwcFgM8hoAAwmtNZJ5ZFDwrgawgoAIxGDwrgmwgoAIxmMQcF8EkEFADmLzMmnwA+h4ACwGis4gF8EwEFgPFzUJgkC/geAgoAL7iSLABfQ0ABYJuQ6B4icuF77Hx/6j9iOQkpgK8hoACwTXh8P5GL3ATwTNm/6UUBfBABBYBtHP6BdlcBgKEIKABs4xcQeJEBHgC+ioACwDZ+qgflIkM8AHwTAQWArT0oANAYAgoA2zAHBUCbBJScnBwZMWKEdOnSRWJiYmTChAly4MABjzJnz56VzMxMiY6OlrCwMElPT5fS0lKPMkeOHJFx48ZJaGiofp7Zs2dLXR2z9AFf4/APsLsKADpCQNmyZYsOHx9//LHk5eVJbW2tjB49Ws6cOeMuM2vWLFm3bp2sWrVKlz9+/Ljccccd7vP19fU6nNTU1Mj27dvltddek2XLlsm8efPa9p0B8I45KADQCIfVimtInzhxQveAqCByww03SEVFhXTr1k1WrFghd955py6zf/9+6d+/vxQUFMjIkSNlw4YNctttt+ngEhsbq8ssXbpU5syZo58vKCjooq9bWVkpERER+vXCw8NbWn0ANjtbeUL2vTH/ohdiGzLpGQkKjbhk9QLQPprz+d2qOSjqBZSoqCi9Lyws1L0qqamp7jL9+vWTxMREHVAUtR88eLA7nChpaWm60kVFRY2+TnV1tT7fcAPg/ehBAdDmAcXpdMrMmTPl2muvlUGDBunHSkpKdA9IZGSkR1kVRtQ5V5mG4cR13nXufHNfVOJybQkJCS2tNgCDMEkWQJsHFDUXZd++fbJy5Uppb9nZ2bq3xrUdPXq03V8TwCVaZsxlUAA0okVT6KdPny7r16+XrVu3So8e6mZfP4iLi9OTX8vLyz16UdQqHnXOVWbnzp0ez+da5eMq82PBwcF6A+CjQzyWJWq6nIOLugE+o1k9KOoXhAonq1evlk2bNkmvXr08zg8bNkwCAwMlPz/f/ZhahqyWFaekpOhjtd+7d6+UlZW5y6gVQWqyzIABA1r/jgB0ONwsEPA9Ac0d1lErdNauXauvheKaM6LmhYSEhOj95MmTJSsrS0+cVaFjxowZOpSoFTyKWpasgsikSZNk8eLF+jnmzp2rn5teEgCNsepr7a4CAJMDSm5urt7fdNNNHo+/+uqrcv/99+uvn332WfHz89MXaFOrb9QKnZdeesld1t/fXw8PTZs2TQeXzp07S0ZGhixcuLBt3hGADsdJQAF8Tquug2IXroMCdAzq10/hy5liXWQIp9+EORIWcwVzUAAvd8mugwIAl4Kzjh4UwNcQUAAYz6qvsbsKAC4xAgoA49GDAvgeAgoA4zm52zngcwgoAIznrGOIB/A1BBQAxuM6KIDvIaAAMB7XQQF8DwEFgPEIKIDvIaAAsFV0nx9ug3Eh3+7fdknqAsAcBBQAtgrq/N87n59P7fenL0ldAJiDgALAVn4BQXZXAYCBCCgAbOXwD7S7CgAMREABYCs/AgqARhBQANjKL4CAAuBcBBQAtnL4MwcFwLkIKABs5ecfYHcVABiIgALAVgzxAGgMAQWArZgkC6AxBBQAtuI6KAAaQ0ABYCt6UAA0hoACwFYO5qAAaAQBBYCt/PyauIrHstq7KgAMQkABYBuHw6H+06SyTmddu9cHgDkIKAC8glVfa3cVAFxCBBQAXsASJwEF8CkEFABewVlHQAF8CQEFgPksNcTDHBTAlxBQAHgFelAA30JAAeAFmIMC+BoCCgCv
wCoewLcQUAB4BXpQAN/SrICSk5MjI0aMkC5dukhMTIxMmDBBDhw44FHmpptu0hdfarg9+OCDHmWOHDki48aNk9DQUP08s2fPlro6JsABOD8nk2QBn9LEa0z/YMuWLZKZmalDigoUjz/+uIwePVo+//xz6dy5s7vclClTZOHChe5jFURc6uvrdTiJi4uT7du3S3Fxsdx3330SGBgoTz/9dFu9LwAdDEM8gG9pVkDZuHGjx/GyZct0D0hhYaHccMMNHoFEBZDGvP/++zrQfPDBBxIbGytDhw6Vp556SubMmSNPPvmkBAVx63UAnizLYhUP4GNaNQeloqJC76OiojweX758uXTt2lUGDRok2dnZ8t1337nPFRQUyODBg3U4cUlLS5PKykopKipq9HWqq6v1+YYbgI7B4fCTwNCIi5arrjxxSeoDwMsDitPplJkzZ8q1116rg4jLvffeK3/729/kww8/1OHkr3/9q/zqV79yny8pKfEIJ4rrWJ0739yXiIgI95aQkNDSagMwjF9gkEQkJl24kOWU8m/+damqBMDbhngaUnNR9u3bJ9u2bfN4fOrUqe6vVU9J9+7dZdSoUXLo0CG58sorW/RaKuhkZWW5j1UPCiEF6Cgc4uff4l9FADqoFvWgTJ8+XdavX697SXr06HHBssnJyXp/8OBBvVdzU0pLSz3KuI7PN28lODhYwsPDPTYAHYVDHAQUAK0JKGqimgonq1evlk2bNkmvXr0u+j179uzRe9WToqSkpMjevXulrKzMXSYvL0+HjgEDBjSnOgA6AIdDxOFHQAHgKaC5wzorVqyQtWvX6muhuOaMqHkhISEhehhHnR87dqxER0fLZ599JrNmzdIrfJKSfhhjVsuSVRCZNGmSLF68WD/H3Llz9XOrnhIAvsYhfgGBdlcCgDf3oOTm5uqVO+pibKpHxLW98cYb+rxaIqyWD6sQ0q9fP3n44YclPT1d1q1b534Of39/PTyk9qo3RU2gVddBaXjdFAA+xCHMQQFwjoDmDvFciJq4qi7mdjE9e/aUd999tzkvDaDDcjDEA+Ac3IsHgL3ULTHoQQHwIwQUALZyqF9EzEEB8CMEFAA2Y4gHwLkIKAAMmCRLDwoATwQUAAb0oPjbXQkAhiGgALBdU5cZX2wlIYCOg4ACwFYOdSnZprAssZx17V0dAIYgoADwCpZY4qwnoAC+goACwDuoHpT6ertrAeASIaAA8A4M8QA+hYACwCuoCbIWQzyAzyCgAPASljidDPEAvoKAAsB7elAY4gF8BgEFgHdgiAfwKQQUAF6CHhTAlxBQAHgFy3KKk2XGgM8goADwDgzxAD6FgALAOzBJFvApBBQAtgsMjZCQ6B4XLFNX851UlRy6ZHUCYK+m3UIUAC6yBLi+NfND/AIlICRCRI6d/zXq66S66pTU1bWuF8Xf37/pNygEYBsCCoBWq62tlS5duojT6WzR98dEdpas/0mWG5J6XrDc35Yvl9/d+qC0xtdffy2XXXZZq54DQPsjoABoE6pno6UBpaa2VmprL94DYzmdre5BAeAdCCgAbOd0WlJX/0O4sSyR0prLpao+UkQcEupXKbHBh8Xf0bLwA8A7EVAA2M5pOd0BZW/VjfJtbQ+pcXbSASXQcVaOV/eRYeEb7a4mgEuIVTwAbFfvtESN8Hx2+kYdRqqdncUSf7HET2qsUDlRmyC7KseJk19ZgM/gXzsAI4Z4vqwaLP+pvkqHknM55GRtvOyrusGG2gGwAwEFgO2clmsOyoWW/7I0GPAlBBQARvSg1P7/HBQAUAgoAIwIKPV1BBQA/0VAAWDEEE+PoH9JbNC/1ULjRkpYEhFQJgM6f2RD7QAYH1Byc3MlKSlJwsPD9ZaSkiIbNmxwnz979qxkZmZKdHS0hIWFSXp6upSWlno8x5EjR2TcuHESGhoqMTExMnv2bC68BPg41YNiOWvk6i4fSEzQN3ppsYjqUXFKgKNah5OUiDUS4Ki1u6oATLwOSo8e
PWTRokXSp08ffe+N1157TW6//XbZvXu3DBw4UGbNmiXvvPOOrFq1SiIiImT69Olyxx13yEcf/fBXj7pXhwoncXFxsn37dikuLpb77rtPAgMD5emnn26v9wjAcKrPZP+Rb2XtR/vVV3Ls7FVyui5aLHFImP//So9OX8paR73s/bfnHzwAOi6HpZJGK0RFRckzzzwjd955p3Tr1k1WrFihv1b2798v/fv3l4KCAhk5cqTubbntttvk+PHjEhsbq8ssXbpU5syZIydOnJCgoKAmvWZlZaUOQPfff3+TvwdA+1GXuH/55Zf1Hy6mmzhxonTu3NnuagA+qaamRpYtWyYVFRV6JKZdriSrekNUT8mZM2f0UE9hYaG+YVhqaqq7TL9+/SQxMdEdUNR+8ODB7nCipKWlybRp06SoqEiuvvrqRl+rurpabw0DijJp0iQ9lATAXmqY9pVXXvGKgHLPPffoP6YAXHpVVVU6oDRFswPK3r17dSBR801UOFi9erUMGDBA9uzZo3szIiPV/TP+S4WRkpIS/bXaNwwnrvOuc+eTk5MjCxYsOOfx4cOHXzSBAbg0fxV5i6FDh3I3Y8Amrg6GdlnF07dvXx1GduzYoXs+MjIy5PPPP5f2lJ2drbuDXNvRo0fb9fUAAIC9mt2DonpJevfurb8eNmyY7Nq1S55//nm566679F9R5eXlHr0oahWPmhSrqP3OnTs9ns+1ysdVpjHBwcF6AwAAvsGvLSbHqfkhKqyo1Tj5+fnucwcOHNDLitWQkKL2aoiorKzMXSYvL08P06hhIgAAgGb3oKihljFjxuiJr6dPn9YrdjZv3izvvfeeXlUzefJkycrK0it7VOiYMWOGDiVqgqwyevRoHUTU5NbFixfreSdz587V106hhwQAALQooKieD3XdEnX9EhVI1EXbVDi55ZZb9Plnn31W/Pz89AXaVK+KWqHz0ksvub/f399f1q9fr+euqOCilvqpOSwLFy5sTjUAAEAH1+rroNjBdR2UpqyjBtD+1PyzkJAQPeRrumPHjrGKB/CCz2/uxQMAAIxDQAEAAMYhoAAAAOMQUAAAgHFafC8eAHBRq/cmTJjgFZNkO3XqZHcVADQBAQVAqwUEBMjf//53u6sBoANhiAcAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAAPDugJKbmytJSUkSHh6ut5SUFNmwYYP7/E033SQOh8Nje/DBBz2e48iRIzJu3DgJDQ2VmJgYmT17ttTV1bXdOwIAAF4voDmFe/ToIYsWLZI+ffqIZVny2muvye233y67d++WgQMH6jJTpkyRhQsXur9HBRGX+vp6HU7i4uJk+/btUlxcLPfdd58EBgbK008/3ZbvCwAAeDGHpZJGK0RFRckzzzwjkydP1j0oQ4cOleeee67Rsqq35bbbbpPjx49LbGysfmzp0qUyZ84cOXHihAQFBTXpNSsrKyUiIkIqKip0Tw4AADBfcz6/WzwHRfWGrFy5Us6cOaOHelyWL18uXbt2lUGDBkl2drZ899137nMFBQUyePBgdzhR0tLSdIWLiorO+1rV1dW6TMMNAAB0XM0a4lH27t2rA8nZs2clLCxMVq9eLQMGDNDn7r33XunZs6fEx8fLZ599pntGDhw4IG+//bY+X1JS4hFOFNexOnc+OTk5smDBguZWFQAA+EpA6du3r+zZs0d3z7z11luSkZEhW7Zs0SFl6tSp7nKqp6R79+4yatQoOXTokFx55ZUtrqTqicnKynIfqx6UhISEFj8fAAAwW7OHeNQ8kd69e8uwYcN0z8aQIUPk+eefb7RscnKy3h88eFDv1eTY0tJSjzKuY3XufIKDg90rh1wb
AADouFp9HRSn06nniDRG9bQoqidFUUNDaoiorKzMXSYvL08HDtcwEQAAQEBzh1rGjBkjiYmJcvr0aVmxYoVs3rxZ3nvvPT2Mo47Hjh0r0dHReg7KrFmz5IYbbtDXTlFGjx6tg8ikSZNk8eLFet7J3LlzJTMzU/eSAAAANDugqJ4Pdd0Sdf0StUxIBQ8VTm655RY5evSofPDBB3qJsVrZo+aIpKen6wDi4u/vL+vXr5dp06bp3pTOnTvrOSwNr5sCAADQ6uug2IHroAAA4H0uyXVQAAAA2gsBBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwToB4Icuy9L6ystLuqgAAgCZyfW67Psc7XEA5ffq03ickJNhdFQAA0ILP8YiIiAuWcVhNiTGGcTqdcuDAARkwYIAcPXpUwsPD7a6SV6dZFfRox9ajLdsObdk2aMe2Q1u2DRU5VDiJj48XPz+/jteDot7UZZddpr9WPyj8sLQe7dh2aMu2Q1u2Ddqx7dCWrXexnhMXJskCAADjEFAAAIBxvDagBAcHy/z58/UeLUc7th3asu3Qlm2Ddmw7tOWl55WTZAEAQMfmtT0oAACg4yKgAAAA4xBQAACAcQgoAADAOF4ZUJYsWSKXX365dOrUSZKTk2Xnzp12V8k4W7dulfHjx+ur9TkcDlmzZo3HeTU3et68edK9e3cJCQmR1NRU+eqrrzzKnDp1SiZOnKgvShQZGSmTJ0+Wqqoq8SU5OTkyYsQI6dKli8TExMiECRP0VYwbOnv2rGRmZkp0dLSEhYVJenq6lJaWepQ5cuSIjBs3TkJDQ/XzzJ49W+rq6sRX5ObmSlJSkvsiVykpKbJhwwb3edqw5RYtWqT/jc+cOdP9GO3ZNE8++aRuu4Zbv3793OdpR5tZXmblypVWUFCQ9corr1hFRUXWlClTrMjISKu0tNTuqhnl3XfftX77299ab7/9tlqlZa1evdrj/KJFi6yIiAhrzZo11r/+9S/rF7/4hdWrVy/r+++/d5e59dZbrSFDhlgff/yx9c9//tPq3bu3dc8991i+JC0tzXr11Vetffv2WXv27LHGjh1rJSYmWlVVVe4yDz74oJWQkGDl5+dbn3zyiTVy5EjrZz/7mft8XV2dNWjQICs1NdXavXu3/n/TtWtXKzs72/IV//jHP6x33nnH+vLLL60DBw5Yjz/+uBUYGKjbVaENW2bnzp3W5ZdfbiUlJVkPPfSQ+3Has2nmz59vDRw40CouLnZvJ06ccJ+nHe3ldQHlmmuusTIzM93H9fX1Vnx8vJWTk2NrvUz244DidDqtuLg465lnnnE/Vl5ebgUHB1uvv/66Pv7888/19+3atctdZsOGDZbD4bD+85//WL6qrKxMt8uWLVvc7aY+aFetWuUu88UXX+gyBQUF+lj90vLz87NKSkrcZXJzc63w8HCrurra8lU/+clPrD//+c+0YQudPn3a6tOnj5WXl2fdeOON7oBCezYvoKg/whpDO9rPq4Z4ampqpLCwUA9HNLwvjzouKCiwtW7e5PDhw1JSUuLRjureCGq4zNWOaq+GdYYPH+4uo8qr9t6xY4f4qoqKCr2PiorSe/XzWFtb69GWqos4MTHRoy0HDx4ssbGx7jJpaWn65mNFRUXia+rr62XlypVy5swZPdRDG7aMGnpQQwsN202hPZtHDW2rofArrrhCD2mrIRuFdrSfV90s8Ntvv9W/3Br+MCjqeP/+/bbVy9uocKI01o6uc2qvxlMbCggI0B/MrjK+Rt1FW43zX3vttTJo0CD9
mGqLoKAgHeYu1JaNtbXrnK/Yu3evDiRqXF+N569evVrfkXzPnj20YTOpgPfpp5/Krl27zjnHz2TTqT/Kli1bJn379pXi4mJZsGCBXH/99bJv3z7a0QBeFVAAu/9iVb+4tm3bZndVvJL6EFBhRPVCvfXWW5KRkSFbtmyxu1pe5+jRo/LQQw9JXl6eXiiAlhszZoz7azWJWwWWnj17yptvvqkXD8BeXjXE07VrV/H39z9nFrU6jouLs61e3sbVVhdqR7UvKyvzOK9mpquVPb7Y1tOnT5f169fLhx9+KD169HA/rtpCDT2Wl5dfsC0ba2vXOV+h/hrt3bu3DBs2TK+OGjJkiDz//PO0YTOpoQf1b/OnP/2p7tVUmwp6L7zwgv5a/QVPe7aM6i256qqr5ODBg/xcGsDP237BqV9u+fn5Ht3u6lh1HaNpevXqpf/xNGxHNWaq5pa42lHt1T9M9cvQZdOmTbq91V8ZvkLNMVbhRA1HqPev2q4h9fMYGBjo0ZZqGbIax27Ylmp4o2HgU3/9quW2aojDV6mfperqatqwmUaNGqXbQvVGuTY1V0zNn3B9TXu2jLqMwqFDh/TlF/i5NIDlhcuM1WqTZcuW6ZUmU6dO1cuMG86ixg8z/NWyN7Wp/81//OMf9dfffPONe5mxare1a9dan332mXX77bc3usz46quvtnbs2GFt27ZNrxjwtWXG06ZN08uxN2/e7LEU8bvvvvNYiqiWHm/atEkvRUxJSdHbj5cijh49Wi9V3rhxo9WtWzefWor42GOP6ZVPhw8f1j9v6litCHv//ff1edqwdRqu4lFoz6Z5+OGH9b9t9XP50Ucf6eXCapmwWq2n0I728rqAorz44ov6h0ZdD0UtO1bX6YCnDz/8UAeTH28ZGRnupcZPPPGEFRsbqwPfqFGj9PUpGjp58qQOJGFhYXrZ3AMPPKCDjy9prA3Vpq6N4qJC3W9+8xu9bDY0NNT65S9/qUNMQ19//bU1ZswYKyQkRP8CVL8Ya2trLV/x61//2urZs6f+N6t+gaufN1c4UWjDtg0otGfT3HXXXVb37t31z+Vll12mjw8ePOg+Tzvay6H+Y3cvDgAAgNfOQQEAAL6BgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAAMc3/AVHIBTtA01b0AAAAAElFTkSuQmCC"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 14
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.397059Z",
     "start_time": "2025-09-02T02:44:06.394410Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Smoke-test the wrapped environment: reset it, inspect the action\n",
    "# space, take one random action and print the resulting transition.\n",
    "def test_env():\n",
    "    state = env.reset()\n",
    "    print('这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态')\n",
    "    print('state=', state)\n",
    "    # Sample output:\n",
    "    #state= [ 0.03490619  0.04873464  0.04908862 -0.00375859]\n",
    "\n",
    "    print('这个游戏一共有2个动作,不是0就是1')\n",
    "    print('env.action_space=', env.action_space)\n",
    "    #env.action_space= Discrete(2)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    action = env.action_space.sample()\n",
    "    print('action=', action)\n",
    "    #action= 1\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    state, reward, over, _ = env.step(action)\n",
    "\n",
    "    print('state=', state)\n",
    "    #state= [ 0.02018229 -0.16441101  0.01547085  0.2661691 ]\n",
    "\n",
    "    print('reward=', reward)\n",
    "    #reward= 1.0\n",
    "\n",
    "    print('over=', over)\n",
    "    #over= False\n",
    "\n",
    "\n",
    "test_env()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态\n",
      "state= [-0.02250056  0.04981454  0.01160632  0.01437772]\n",
      "这个游戏一共有2个动作,不是0就是1\n",
      "env.action_space= Discrete(2)\n",
      "随机一个动作\n",
      "action= 0\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n",
      "state= [-0.02150427 -0.14547193  0.01189387  0.31069985]\n",
      "reward= 1.0\n",
      "over= False\n"
     ]
    }
   ],
   "execution_count": 15
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.419723Z",
     "start_time": "2025-09-02T02:44:06.416125Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "# Q-network used for action selection: maps a 4-dim CartPole state to\n",
    "# one score (Q-value) per action.\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "model"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=4, out_features=128, bias=True)\n",
       "  (1): ReLU()\n",
       "  (2): Linear(in_features=128, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 16
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.435576Z",
     "start_time": "2025-09-02T02:44:06.432438Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import random\n",
    "\n",
    "# Exploration rate for the epsilon-greedy policy.\n",
    "EPSILON = 0.01\n",
    "\n",
    "\n",
    "# Epsilon-greedy action selection: with probability EPSILON take a\n",
    "# uniformly random action, otherwise take the greedy action according\n",
    "# to the Q-network.\n",
    "def get_action(state):\n",
    "    if random.random() < EPSILON:\n",
    "        return random.choice([0, 1])\n",
    "\n",
    "    # Query the Q-network; no gradients are needed during rollout, so\n",
    "    # avoid building an autograd graph.\n",
    "    state = torch.FloatTensor(state).reshape(1, 4)\n",
    "    with torch.no_grad():\n",
    "        return model(state).argmax().item()\n",
    "\n",
    "\n",
    "get_action([0.0013847, -0.01194451, 0.04260966, 0.00688801])"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 17
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.464797Z",
     "start_time": "2025-09-02T02:44:06.455577Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Replay buffer: list of (state, action, reward, next_state, over).\n",
    "datas = []\n",
    "\n",
    "\n",
    "# Play full episodes until at least 200 new transitions are collected,\n",
    "# then evict the oldest transitions so the buffer holds at most 10000.\n",
    "# Returns (number of transitions added, number of transitions dropped).\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    # Keep playing until at least 200 new transitions are stored.\n",
    "    while len(datas) - old_count < 200:\n",
    "        # Start a fresh episode.\n",
    "        state = env.reset()\n",
    "\n",
    "        # Play until the episode ends.\n",
    "        over = False\n",
    "        while not over:\n",
    "            # Choose an action for the current state.\n",
    "            action = get_action(state)\n",
    "\n",
    "            # Execute it and observe the outcome.\n",
    "            next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "            # Store the transition in the replay buffer.\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            # Move on to the next step.\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 10000, 0)\n",
    "\n",
    "    # Evict the oldest transitions with a single O(n) slice delete\n",
    "    # instead of popping the head repeatedly, which is O(n^2).\n",
    "    del datas[:drop_count]\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((211, 0), 211)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 18
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.478344Z",
     "start_time": "2025-09-02T02:44:06.473956Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Draw a random mini-batch from the replay buffer and convert it to\n",
    "# tensors. batch_size defaults to 64, keeping existing callers intact.\n",
    "def get_sample(batch_size=64):\n",
    "    # Uniform sampling without replacement from the replay buffer.\n",
    "    samples = random.sample(datas, batch_size)\n",
    "\n",
    "    #[b, 4]\n",
    "    state = torch.FloatTensor([i[0] for i in samples])\n",
    "    #[b]\n",
    "    action = torch.LongTensor([i[1] for i in samples])\n",
    "    #[b]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples])\n",
    "    #[b, 4]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples])\n",
    "    #[b]\n",
    "    over = torch.LongTensor([i[4] for i in samples])\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state[:5], action, reward, next_state[:5], over"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.0381, -0.1637,  0.0088,  0.3453],\n",
       "         [-0.0692, -0.0475,  0.0502,  0.0755],\n",
       "         [-0.0472, -0.2393,  0.0228,  0.2936],\n",
       "         [-0.0011, -0.1488,  0.0175,  0.2773],\n",
       "         [-0.0665, -0.1785,  0.1309,  0.6769]]),\n",
       " tensor([1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0,\n",
       "         1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0,\n",
       "         1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0]),\n",
       " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n",
       " tensor([[-0.0414,  0.0313,  0.0157,  0.0554],\n",
       "         [-0.0701, -0.2433,  0.0517,  0.3836],\n",
       "         [-0.0520, -0.0445,  0.0287,  0.0082],\n",
       "         [-0.0040,  0.0461,  0.0231, -0.0098],\n",
       "         [-0.0700, -0.3752,  0.1444,  1.0077]]),\n",
       " tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 19
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.496330Z",
     "start_time": "2025-09-02T02:44:06.490713Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Q(s, a): evaluate the network on a batch of states and pick out\n",
    "# the predicted value of the action that was actually taken.\n",
    "def get_value(state, action):\n",
    "    # Q-values (logits) for every action in each state.\n",
    "    #[b, 4] -> [b, 2]\n",
    "    value = model(state)\n",
    "\n",
    "    # Select the entry of the action actually executed. The reward and\n",
    "    # next_state are unknown at decision time, so they are not used.\n",
    "    # Index with the real batch length instead of a hard-coded 64 so\n",
    "    # this works for any batch size.\n",
    "    #[b, 2] -> [b]\n",
    "    value = value[range(len(action)), action]\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0680, 0.0710, 0.0931, 0.0634, 0.0637, 0.0708, 0.0721, 0.0799, 0.0630,\n",
       "        0.0725, 0.0632, 0.0572, 0.0774, 0.0682, 0.0749, 0.0675, 0.0737, 0.0670,\n",
       "        0.0668, 0.0631, 0.0708, 0.0688, 0.0704, 0.0707, 0.0699, 0.0658, 0.0626,\n",
       "        0.0648, 0.0664, 0.0713, 0.0725, 0.0697, 0.0707, 0.0698, 0.0680, 0.0827,\n",
       "        0.0732, 0.0693, 0.0694, 0.0782, 0.0703, 0.0708, 0.0688, 0.0718, 0.0611,\n",
       "        0.0748, 0.0718, 0.0630, 0.0677, 0.0721, 0.0699, 0.0599, 0.0902, 0.0651,\n",
       "        0.0682, 0.0694, 0.0704, 0.0671, 0.0589, 0.0780, 0.0640, 0.0892, 0.0606,\n",
       "        0.0733], grad_fn=<IndexBackward0>)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 20
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.531198Z",
     "start_time": "2025-09-02T02:44:06.518058Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    # TD target for Q-learning: reward + 0.98 * max_a Q(next_state, a),\n",
    "    # with the bootstrap term zeroed out when the episode ended.\n",
    "    # NOTE: this tutorial bootstraps from the online model itself; a\n",
    "    # classic DQN would use a delayed-update target network here.\n",
    "\n",
    "    # Q-values of the next state; no gradient may flow into the target.\n",
    "    #[b, 4] -> [b, 2]\n",
    "    with torch.no_grad():\n",
    "        target = model(next_state)\n",
    "\n",
    "    # Best achievable value in the next state.\n",
    "    #[b, 2] -> [b]\n",
    "    target = target.max(dim=1)[0]\n",
    "\n",
    "    # If next_state is terminal there is nothing left to earn, so the\n",
    "    # bootstrap term becomes 0. Vectorized over the whole batch instead\n",
    "    # of a Python loop hard-coded to 64 elements, so any batch size\n",
    "    # works (over is a 0/1 LongTensor).\n",
    "    #[b]\n",
    "    target = target * (1 - over)\n",
    "\n",
    "    # Discount the future value.\n",
    "    #[b]\n",
    "    target *= 0.98\n",
    "\n",
    "    # Add the immediate reward to obtain the final target.\n",
    "    #[b] + [b] -> [b]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.0703, 1.0892, 1.0700, 1.0698, 1.0654, 1.0659, 1.0697, 1.0618, 1.0834,\n",
       "        1.0745, 1.0697, 1.0653, 1.0781, 1.0609, 1.0613, 1.0600, 1.0852, 1.0718,\n",
       "        1.0626, 1.0000, 1.0617, 1.0687, 1.0622, 1.0657, 1.0705, 1.0609, 1.0727,\n",
       "        1.0742, 1.0618, 1.0677, 1.0630, 1.0699, 1.0884, 1.0587, 1.0783, 1.0760,\n",
       "        1.0693, 1.0694, 1.0561, 1.0711, 1.0903, 1.0911, 1.0638, 1.0874, 1.0648,\n",
       "        1.0617, 1.0666, 1.0644, 1.0697, 1.0558, 1.0653, 1.0619, 1.0703, 1.0646,\n",
       "        1.0636, 1.0888, 1.0907, 1.0685, 1.0706, 1.0715, 1.0583, 1.0722, 1.0646,\n",
       "        1.0710])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 21
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:06.545364Z",
     "start_time": "2025-09-02T02:44:06.541697Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "# Run one full episode with the current policy and return the total\n",
    "# reward. If play is True, render each frame as an inline animation.\n",
    "def test(play):\n",
    "    # Start a fresh episode.\n",
    "    state = env.reset()\n",
    "\n",
    "    # Accumulated reward for the episode; higher is better.\n",
    "    reward_sum = 0\n",
    "\n",
    "    # Play until the episode ends.\n",
    "    over = False\n",
    "    while not over:\n",
    "        # Choose an action for the current state.\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Execute it and observe the outcome.\n",
    "        state, reward, over, _ = env.step(action)\n",
    "        reward_sum += reward\n",
    "\n",
    "        # Render the animation frame by frame.\n",
    "        if play:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "33.0"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 22
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:44:49.081764Z",
     "start_time": "2025-09-02T02:44:06.557762Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Train the Q-network with simple DQN-style updates: alternately\n",
    "# collect fresh experience and fit the network on replayed batches.\n",
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    # Train for N rounds.\n",
    "    for epoch in range(500):\n",
    "        # Refresh the replay buffer with new transitions.\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        # After each data refresh, run N gradient steps.\n",
    "        for i in range(200):\n",
    "            # Sample a mini-batch of transitions.\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            # Compute predicted values and TD targets for the batch.\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            # Update the network parameters.\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        if epoch % 50 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 431 220 0 9.4\n",
      "50 10000 389 389 181.25\n",
      "100 10000 200 200 200.0\n",
      "150 10000 200 200 190.25\n",
      "200 10000 200 200 195.9\n",
      "250 10000 200 200 196.95\n",
      "300 10000 200 200 200.0\n",
      "350 10000 392 392 194.85\n",
      "400 10000 200 200 200.0\n",
      "450 10000 390 390 171.5\n"
     ]
    }
   ],
   "execution_count": 23
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Watch the trained agent play one rendered episode.\n",
    "test(play=True)\n",
    "\n"
   ],
   "execution_count": 24,
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
